{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Closed-Loop Evaluation\n",
    "In this notebook you will evaluate how well Urban Driver controls the SDV (self-driving vehicle), using a protocol called *closed-loop* evaluation.\n",
    "\n",
    "**Note: this notebook assumes you've already run the [training notebook](./train.ipynb) and stored your model successfully (or that you have stored a pre-trained one).**\n",
    "\n",
    "**Note: for a detailed explanation of what closed-loop evaluation (CLE) is, please refer to our [planning notebook](../planning/closed_loop_test.ipynb)**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from prettytable import PrettyTable\n",
    "\n",
    "from l5kit.configs import load_config_data\n",
    "from l5kit.data import LocalDataManager, ChunkedDataset\n",
    "\n",
    "from l5kit.dataset import EgoDatasetVectorized\n",
    "from l5kit.vectorization.vectorizer_builder import build_vectorizer\n",
    "\n",
    "from l5kit.simulation.dataset import SimulationConfig\n",
    "from l5kit.simulation.unroll import ClosedLoopSimulator\n",
    "from l5kit.cle.closed_loop_evaluator import ClosedLoopEvaluator, EvaluationPlan\n",
    "from l5kit.cle.metrics import (CollisionFrontMetric, CollisionRearMetric, CollisionSideMetric,\n",
    "                               DisplacementErrorL2Metric, DistanceToRefTrajectoryMetric)\n",
    "from l5kit.cle.validators import RangeValidator, ValidationCountingAggregator\n",
    "\n",
    "from l5kit.visualization.visualizer.zarr_utils import simulation_out_to_visualizer_scene\n",
    "from l5kit.visualization.visualizer.visualizer import visualize\n",
    "from bokeh.io import output_notebook, show\n",
    "from l5kit.data import MapAPI\n",
    "\n",
    "from collections import defaultdict\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare data path and load cfg\n",
    "\n",
    "Setting the `L5KIT_DATA_FOLDER` environment variable tells the notebook where the data is stored.\n",
    "\n",
    "We then load the config file, which holds the relative data paths and other settings (rasteriser, training params, ...)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set env variable for data\n",
    "os.environ[\"L5KIT_DATA_FOLDER\"] = \"/tmp/l5kit_data\"\n",
    "dm = LocalDataManager(None)\n",
    "# get config\n",
    "cfg = load_config_data(\"./config.yaml\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"/tmp/urban_driver.pt\"\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "# map_location lets a checkpoint saved on a GPU load on a CPU-only machine\n",
    "model = torch.load(model_path, map_location=device).to(device)\n",
    "model = model.eval()\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the evaluation data\n",
    "Unlike training and open-loop evaluation, closed-loop evaluation is intrinsically sequential: the input at each step depends on the model's output at the previous step. As such, we won't be using any of PyTorch's parallelisation functionalities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== INIT DATASET\n",
    "eval_cfg = cfg[\"val_data_loader\"]\n",
    "eval_zarr = ChunkedDataset(dm.require(eval_cfg[\"key\"])).open()\n",
    "vectorizer = build_vectorizer(cfg, dm)\n",
    "eval_dataset = EgoDatasetVectorized(cfg, eval_zarr, vectorizer)\n",
    "print(eval_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define some simulation properties\n",
    "We define here some common simulation properties, such as the length of the simulation and how many scenes to simulate.\n",
    "\n",
    "**NOTE: these properties have a significant impact on the execution time. We suggest increasing them only if your setup includes a GPU.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_scenes_to_unroll = 10\n",
    "num_simulation_steps = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Closed-loop simulation\n",
    "\n",
    "We define a closed-loop simulation in which the model drives the SDV for `num_simulation_steps` steps (`use_ego_gt=False`) while the other agents are replayed from the log (`use_agents_gt=True`).\n",
    "\n",
    "Then, we unroll the selected scenes.\n",
    "The simulation output contains all the information related to the scene, including the annotated and simulated positions, states, and trajectories of the SDV and the agents.  \n",
    "If you want to know more about what the simulation output contains, please refer to the source code of the class `SimulationOutput`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==== DEFINE CLOSED-LOOP SIMULATION\n",
    "sim_cfg = SimulationConfig(use_ego_gt=False, use_agents_gt=True, disable_new_agents=True,\n",
    "                           distance_th_far=500, distance_th_close=50, num_simulation_steps=num_simulation_steps,\n",
    "                           start_frame_index=0, show_info=True)\n",
    "\n",
    "sim_loop = ClosedLoopSimulator(sim_cfg, eval_dataset, device, model_ego=model, model_agents=None)"
   ]
  },
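  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==== UNROLL\n",
    "# pick scenes evenly spaced across the dataset; the stride is at least 1 so range() never gets a zero step\n",
    "scene_stride = max(1, len(eval_zarr.scenes) // num_scenes_to_unroll)\n",
    "# cap the selection, since the integer stride can otherwise yield more than num_scenes_to_unroll scenes\n",
    "scenes_to_unroll = list(range(0, len(eval_zarr.scenes), scene_stride))[:num_scenes_to_unroll]\n",
    "sim_outs = sim_loop.unroll(scenes_to_unroll)"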
   ]
  },
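  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below is a minimal toy sketch of what *closed-loop* means. It is not l5kit code, and `toy_policy` and `closed_loop_rollout` are hypothetical names introduced only for illustration. The policy's output is fed back as its next input, so the steps have to run one after another, and the simulated positions can drift away from the logged ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration only: NOT l5kit code. All names below are hypothetical.\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def toy_policy(position, target):\n",
    "    # hypothetical policy: move 20% of the way towards the target at each step\n",
    "    return 0.2 * (target - position)\n",
    "\n",
    "\n",
    "def closed_loop_rollout(start, logged_positions, num_steps):\n",
    "    # unroll the policy from its own previous output instead of the logged state\n",
    "    position = start\n",
    "    rollout = []\n",
    "    for step in range(num_steps):\n",
    "        # the next input depends on the previous prediction, so steps cannot run in parallel\n",
    "        position = position + toy_policy(position, logged_positions[step])\n",
    "        rollout.append(position)\n",
    "    return np.array(rollout)\n",
    "\n",
    "\n",
    "toy_logged = np.linspace(1.0, 5.0, num=5)  # stand-in for the annotated SDV positions\n",
    "toy_simulated = closed_loop_rollout(start=0.0, logged_positions=toy_logged, num_steps=5)\n",
    "# per-step gap between simulated and logged positions (a 1D analogue of a displacement error)\n",
    "print(np.abs(toy_simulated - toy_logged))"
   ]
  },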
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Closed-loop metrics\n",
    "\n",
    "**Note: for a detailed explanation of CLE metrics, please refer again to our [planning notebook](../planning/closed_loop_test.ipynb)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics = [DisplacementErrorL2Metric(),\n",
    "           DistanceToRefTrajectoryMetric(),\n",
    "           CollisionFrontMetric(),\n",
    "           CollisionRearMetric(),\n",
    "           CollisionSideMetric()]\n",
    "\n",
    "validators = [RangeValidator(\"displacement_error_l2\", DisplacementErrorL2Metric, max_value=30),\n",
    "              RangeValidator(\"distance_ref_trajectory\", DistanceToRefTrajectoryMetric, max_value=4),\n",
    "              RangeValidator(\"collision_front\", CollisionFrontMetric, max_value=0),\n",
    "              RangeValidator(\"collision_rear\", CollisionRearMetric, max_value=0),\n",
    "              RangeValidator(\"collision_side\", CollisionSideMetric, max_value=0)]\n",
    "\n",
    "intervention_validators = [\"displacement_error_l2\",\n",
    "                           \"distance_ref_trajectory\",\n",
    "                           \"collision_front\",\n",
    "                           \"collision_rear\",\n",
    "                           \"collision_side\"]\n",
    "\n",
    "cle_evaluator = ClosedLoopEvaluator(EvaluationPlan(metrics=metrics,\n",
    "                                                   validators=validators,\n",
    "                                                   composite_metrics=[],\n",
    "                                                   intervention_validators=intervention_validators))"
   ]
  },
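  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough mental model of how a range check and a counting aggregation fit together, the cell below is a toy sketch under stated assumptions: a scene fails when any frame of a metric exceeds `max_value`, and the aggregate is the number of failing scenes. It does not call l5kit, and all names prefixed with `toy_` are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration only: NOT the l5kit implementation. All names below are hypothetical.\n",
    "import numpy as np\n",
    "\n",
    "# per-frame metric values for three toy scenes (e.g. distance to the reference trajectory)\n",
    "toy_metric_per_scene = {\n",
    "    0: np.array([0.5, 1.0, 1.5]),\n",
    "    1: np.array([0.5, 4.5, 6.0]),\n",
    "    2: np.array([0.0, 0.2, 0.1]),\n",
    "}\n",
    "toy_max_value = 4.0  # plays the same role as `max_value` in a range validator\n",
    "\n",
    "# assumption: a scene fails the check when any frame exceeds the threshold\n",
    "toy_failed_scenes = [scene_id for scene_id, per_frame in toy_metric_per_scene.items()\n",
    "                     if np.any(per_frame > toy_max_value)]\n",
    "\n",
    "# counting the failed scenes gives one number for this check\n",
    "print(f\"failed scenes: {toy_failed_scenes}, count: {len(toy_failed_scenes)}\")"
   ]
  },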
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quantitative evaluation\n",
    "\n",
    "We can now evaluate the metrics on the simulation outputs, collect the validation results, and aggregate them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cle_evaluator.evaluate(sim_outs)\n",
    "validation_results = cle_evaluator.validation_results()\n",
    "agg = ValidationCountingAggregator().aggregate(validation_results)\n",
    "cle_evaluator.reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reporting errors from the closed-loop evaluation\n",
    "\n",
    "We can now report the aggregated result for each validator in a table and plot the values as a bar chart."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fields = [\"metric\", \"value\"]\n",
    "table = PrettyTable(field_names=fields)\n",
    "\n",
    "values = []\n",
    "names = []\n",
    "\n",
    "for metric_name in agg:\n",
    "    table.add_row([metric_name, agg[metric_name].item()])\n",
    "    values.append(agg[metric_name].item())\n",
    "    names.append(metric_name)\n",
    "\n",
    "print(table)\n",
    "\n",
    "plt.bar(np.arange(len(names)), values)\n",
    "plt.xticks(np.arange(len(names)), names, rotation=60, ha='right')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Qualitative evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualise the closed-loop\n",
    "\n",
    "We can now visualise the scenes we unrolled previously.\n",
    "\n",
    "**The policy is now in full control of the SDV as it moves through the annotated scene.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "output_notebook()\n",
    "mapAPI = MapAPI.from_cfg(dm, cfg)\n",
    "for sim_out in sim_outs:  # for each scene\n",
    "    vis_in = simulation_out_to_visualizer_scene(sim_out, mapAPI)\n",
    "    show(visualize(sim_out.scene_id, vis_in))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
